import ssl
from typing import Callable

# for type annotations
import requests as _clean_requests
import requests.sessions as _clean_requests_sessions
from urllib3.connectionpool import HTTPSConnectionPool as _clean_HTTPSConnectionPool
from urllib3.connection import VerifiedHTTPSConnection as _clean_VerifiedHTTPSConnection
import urllib3.connection as _clean_urllib3_connection

from requests_ja3.decoder import Decoder
from requests_ja3.patcher_utils import _module_from_class, _wrap
from requests_ja3.ssl_utils import SSLUtils

class Patcher:
    """Monkey-patch a ``requests`` module so its TLS handshakes aim at a target JA3 fingerprint.

    All methods are static; the class only serves as a namespace.
    """
    @staticmethod
    def patch (src_requests_module: type (_clean_requests), target_ja3_str: str):
        """Install the JA3 hooks on *src_requests_module* (mutated in place).

        :param src_requests_module: the ``requests`` module object whose classes get wrapped.
        :param target_ja3_str: JA3 string to imitate; parsed with ``Decoder.decode``.
        """
        target_ja3 = Decoder.decode (target_ja3_str)
        # The target never changes after this point, so derive everything the hook needs once
        # here instead of on every TLS handshake.
        target_cipher_numbers = target_ja3 ["accepted_ciphers"]
        target_ciphers_str = SSLUtils.cipher_numbers_to_string (target_cipher_numbers)
        all_names = SSLUtils.get_cipher_names ()
        # A set, because membership is tested once per cipher the context enables.
        target_names = {all_names [number] for number in target_cipher_numbers}
        def ssl_wrap_socket_hook (*args, **kwargs):
            # Pre-hook for urllib3's ssl_wrap_socket: returns the (args, kwargs) to call it with.
            # NOTE(review): this dumps every kwarg, which may include ``key_password`` — debug only.
            print (f"ssl_wrap_socket called with args {args} kwargs {kwargs}")
            # NOTE(review): urllib3's ssl_wrap_socket appears to consult ``ciphers`` only when
            # ``ssl_context`` is None; when a context is passed this override may have no effect.
            # Confirm for the urllib3 version in use (``ssl_context.set_ciphers`` may be needed).
            kwargs ["ciphers"] = target_ciphers_str
            ssl_context = kwargs.get ("ssl_context")
            if ssl_context is not None:
                # Diagnostic: each cipher the context currently enables, and whether it is a target one.
                for cipher in ssl_context.get_ciphers ():
                    name = cipher ["name"]
                    print (f"{name} ({name in target_names})")
            return args, kwargs
        Patcher._inner_patch (src_requests_module, ssl_wrap_socket_hook)
    @staticmethod
    def check (requests_module: type (_clean_requests), target_ja3_str: str, url: str = "https://ja3er.com/json", timeout = None):
        """Verify that *requests_module* presents *target_ja3_str* to a JA3 echo service.

        Fetches *url* and reads the ``"ja3"`` key of its JSON reply. On mismatch, prints a
        per-field comparison of the decoded target and observed fingerprints, then raises.

        :param requests_module: the (possibly patched) ``requests`` module to test.
        :param target_ja3_str: the JA3 string the module is expected to produce.
        :param url: JA3 echo endpoint; its JSON reply must contain a ``"ja3"`` key.
        :param timeout: forwarded to ``requests_module.get``; ``None`` (default) waits indefinitely.
        :raises Exception: if the observed JA3 string differs from *target_ja3_str*.
        """
        real_ja3_str = requests_module.get (url, timeout = timeout).json () ["ja3"]
        if real_ja3_str == target_ja3_str:
            return
        basic_error_message = f"real ja3 {real_ja3_str} does not match target ja3 {target_ja3_str}"
        target_ja3 = Decoder.decode (target_ja3_str)
        real_ja3 = Decoder.decode (real_ja3_str)

        print ('-' * 10)
        for ja3_field_name, target_ja3_field_value in target_ja3.items ():
            # .get: a missing field must not replace the mismatch exception below with a KeyError.
            real_ja3_field_value = real_ja3.get (ja3_field_name)
            if target_ja3_field_value == real_ja3_field_value: print (f"Field {ja3_field_name} matches! ({target_ja3_field_value} in both)")
            else: print (f"Field {ja3_field_name} does not match (target {target_ja3_field_value}, real {real_ja3_field_value})")
        print ('-' * 10)

        raise Exception (basic_error_message)
    @staticmethod
    def _inner_patch (src_requests_module: type (_clean_requests), ssl_wrap_socket_hook: Callable):
        """Route *ssl_wrap_socket_hook* into the urllib3 stack used by *src_requests_module*.

        Chain: ``HTTPAdapter.get_connection`` (post-hook) -> the returned pool's
        ``_make_request`` (pre-hook) -> the connection module's ``ssl_wrap_socket`` (pre-hook).

        Objects that outlive a single request (the pool returned by ``get_connection`` and the
        connection module) are tagged once wrapped. Repeated requests therefore do not stack an
        extra wrapper layer each time, which would make the hook run many times and grow the
        call depth without bound.
        """
        # Marker is unique per hook, so a second ``patch`` call still installs its own layer.
        patched_marker = f"_requests_ja3_patched_{id (ssl_wrap_socket_hook)}"
        src_session_class = src_requests_module.Session
        src_session_class.request = _wrap (src_session_class.request, "Session.request")
        src_sessions_module: type (_clean_requests_sessions) = _module_from_class (src_session_class)
        src_httpadapter_class: type = src_sessions_module.HTTPAdapter
        def get_connection_hook (connection_pool: _clean_HTTPSConnectionPool):
            # get_connection runs per request and may hand back a reused pool; wrap it only once.
            if getattr (connection_pool, patched_marker, False):
                return connection_pool
            def _make_request_hook (connection: _clean_VerifiedHTTPSConnection, *args, **kwargs):
                src_connection_module: type (_clean_urllib3_connection) = _module_from_class (connection.__class__)
                # Module-level replacement persists, so install it once per module, not per request.
                if not getattr (src_connection_module, patched_marker, False):
                    src_connection_module.ssl_wrap_socket = _wrap (src_connection_module.ssl_wrap_socket, "urllib3.util.ssl_.ssl_wrap_socket", pre_hook = ssl_wrap_socket_hook)
                    setattr (src_connection_module, patched_marker, True)
                return (connection, *args), kwargs
            connection_pool._make_request = _wrap (connection_pool._make_request, "HTTPSConnectionPool._make_request", pre_hook = _make_request_hook)
            setattr (connection_pool, patched_marker, True)
            return connection_pool
        src_httpadapter_class.get_connection = _wrap (src_httpadapter_class.get_connection, "HTTPAdapter.get_connection", post_hook = get_connection_hook)
